import tensorflow as tf
import numpy as np

np.random.seed(1234)
import os
import time
import datetime
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from builddata import *
from model_R_MeN_TripleCls_CNN import RMeN

# Parameters
# ==================================================
def _parse_bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag could never be turned
    off from the command line.

    Accepted spellings (case-insensitive): true/t/yes/y/1 and false/f/no/n/0.
    The empty string maps to False, which is what ``bool("")`` returned before.

    Raises:
        ValueError: if ``value`` is not a recognised spelling; argparse reports
            this as a regular usage error.
    """
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise ValueError("expected a boolean value, got %r" % (value,))


parser = ArgumentParser("RMeN", formatter_class=ArgumentDefaultsHelpFormatter, conflict_handler='resolve')

parser.add_argument("--data", default="./data/", help="Data sources.")
# The run folder is the parent directory under which the output directory is created.
parser.add_argument("--run_folder", default="../", help="Parent folder for run outputs.")
parser.add_argument("--name", default="WN11", help="Name of the dataset.")
parser.add_argument("--embedding_dim", default=50, type=int, help="Dimensionality of character embedding")
parser.add_argument("--learning_rate", default=0.0001, type=float, help="Learning rate")
parser.add_argument("--batch_size", default=8, type=int, help="Batch Size")
parser.add_argument("--num_epochs", default=50, type=int, help="Number of training epochs")
parser.add_argument("--saveStep", default=1, type=int, help="")
parser.add_argument("--neg_ratio", default=1.0, type=float, help="Number of negative triples generated by positive")
# Boolean flags go through _parse_bool so that "False"/"0"/"no" really disable them.
parser.add_argument("--allow_soft_placement", default=True, type=_parse_bool, help="Allow device soft device placement")
parser.add_argument("--log_device_placement", default=False, type=_parse_bool, help="Log placement of ops on devices")
parser.add_argument("--model_name", default='WN11', help="")
parser.add_argument("--dropout_keep_prob", default=0.5, type=float, help="Dropout keep probability")
parser.add_argument("--num_heads", default=3, type=int, help="Number of attention heads. 1 2 4")
parser.add_argument("--memory_slots", default=1, type=int, help="Number of memory slots. 1 2 4")
parser.add_argument("--head_size", default=50, type=int, help="")
parser.add_argument("--gate_style", default='memory', help="unit,memory")
parser.add_argument("--attention_mlp_layers", default=2, type=int, help="2 3 4")
parser.add_argument("--use_pos", default=1, type=int, help="1 when using positional embeddings. Otherwise.")
parser.add_argument("--num_filters", default=20, type=int, help="Number of filters per filter size")
args = parser.parse_args()
print(args)

# Load data
print("Loading data...")

train, valid, test, words_indexes, indexes_words, \
headTailSelector, entity2id, id2entity, relation2id, id2relation = build_data(path=args.data, name=args.name)
data_size = len(train)
train_batch = Batch_Loader(train, words_indexes, indexes_words, headTailSelector, \
                           entity2id, id2entity, relation2id, id2relation, batch_size=args.batch_size,
                           neg_ratio=args.neg_ratio)

entity_array = np.array(list(train_batch.indexes_ents.keys()))

print("Using pre-trained model.")
# Build the init-file paths with os.path.join so they stay valid whether or not
# --data ends with a path separator.
dataset_dir = os.path.join(args.data, args.name)
initEnt, initRel = init_norm_Vector(
    os.path.join(dataset_dir, 'relation2vec' + str(args.embedding_dim) + '.init'),
    os.path.join(dataset_dir, 'entity2vec' + str(args.embedding_dim) + '.init'),
    args.embedding_dim)

# One embedding row per vocabulary entry, addressed by its words_indexes value.
# Every row is expected to be overwritten by the loop below.
lstEmbed = np.empty([len(words_indexes), args.embedding_dim], dtype=np.float32)

for _word in words_indexes:
    _ind = words_indexes[_word]
    # Relations take precedence when a token appears in both dictionaries.
    if _word in relation2id:
        lstEmbed[_ind] = initRel[relation2id[_word]]
    elif _word in entity2id:
        lstEmbed[_ind] = initEnt[entity2id[_word]]
    else:
        # Fail fast: carrying on would leave this and all later rows as
        # uninitialised memory and silently train on garbage values.
        raise ValueError(
            "Vocabulary token %r is neither a known relation nor a known entity" % (_word,))

# Sanity check on the vocabulary size; an explicit raise is used because
# assert statements are stripped when Python runs with -O.
if len(words_indexes) % (len(entity2id) + len(relation2id)) != 0:
    raise ValueError(
        "Vocabulary size %d is not a multiple of the number of entities plus relations (%d)"
        % (len(words_indexes), len(entity2id) + len(relation2id)))

#######################
def _read_labeled_triples(path):
    """Read a labelled triple file into index and label arrays.

    Each line is handed to ``parse_line``, which yields (sub, obj, rel, val).
    The three tokens are mapped through ``words_indexes`` and stored in
    (sub, rel, obj) order; ``val`` is collected as the label.

    Returns:
        A pair (x, y): the index triples as int32 and the labels as float32.
    """
    with open(path) as handle:
        raw_lines = handle.readlines()

    triples = []
    labels = []
    for raw in raw_lines:
        sub, obj, rel, val = parse_line(raw)
        triples.append([words_indexes[token] for token in (sub, rel, obj)])
        labels.append(val)

    return np.array(triples).astype(np.int32), np.array(labels).astype(np.float32)


# Validation and test splits live next to each other in the dataset folder.
x_valid, y_valid = _read_labeled_triples(args.data + '/' + args.name + '/valid.txt')
x_test, y_test = _read_labeled_triples(args.data + '/' + args.name + '/test.txt')

print(len(x_test), len(x_valid))

print("Loading data... finished!")

# lstdictR = {}
# for i in range(len(x_test)):
#     if x_test[i][1] not in lstdictR:
#         lstdictR[x_test[i][1]] = []
#     lstdictR[x_test[i][1]].append([x_test[i], y_test[i]])
#
# for tmp in lstdictR:
#     print(tmp, train_batch.indexes_rels[tmp], len(lstdictR[tmp]))
#
# print(sum([len(lstdictR[i]) for i in lstdictR]), len(x_test))
# print(lstdictR.keys())

# Training
# ==================================================
with tf.Graph().as_default():
    # Fix the graph-level seed so that runs are repeatable.
    tf.compat.v1.set_random_seed(1234)
    session_conf = tf.compat.v1.ConfigProto(allow_soft_placement=args.allow_soft_placement,
                                  log_device_placement=args.log_device_placement)
    # Let the process grab GPU memory on demand instead of reserving it all up front.
    session_conf.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(config=session_conf)
    with sess.as_default():
        global_step = tf.Variable(0, name="global_step", trainable=False)
        r_men = RMeN(
            vocab_size=len(words_indexes),
            # The model is told this row count per batch; evaluation batches must match it.
            # Presumably batch_size positives plus their generated negatives - confirm
            # against Batch_Loader.
            batch_size=args.batch_size * (int(args.neg_ratio) + 1),
            initialization=lstEmbed,
            embedding_size=args.embedding_dim,
            num_heads=args.num_heads,
            mem_slots=args.memory_slots,
            head_size=args.head_size,
            attention_mlp_layers=args.attention_mlp_layers,
            use_pos=args.use_pos,
            num_filters=args.num_filters,
        )

        #Optimizer
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.learning_rate)
        # optimizer = tf.train.RMSPropOptimizer(learning_rate=args.learning_rate)
        # optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        grads_and_vars = optimizer.compute_gradients(r_men.loss)
        # Applying the gradients also increments global_step.
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        out_dir = os.path.abspath(os.path.join(args.run_folder, "runs_RMeN_TripleCls_ConvKB_max_pool", args.model_name))
        print("Writing to {}\n".format(out_dir))

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # Initialize all variables. Accessed through tf.compat.v1 like every other
        # TF1 symbol in this script; the bare tf.global_variables_initializer is
        # not available in the TensorFlow 2 top-level namespace.
        sess.run(tf.compat.v1.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """
            Run one optimisation step on a batch and return the batch loss.

            Feeds the batch into the model with the configured dropout keep
            probability, executes the training op and reports the loss that
            was fetched in the same session call.
            """
            fetches = [train_op, global_step, r_men.loss]
            feeds = {
                r_men.dropout_keep_prob: args.dropout_keep_prob,
                r_men.input_y: y_batch,
                r_men.input_x: x_batch,
            }
            fetched = sess.run(fetches, feeds)

            # Only the loss (third fetch) is of interest to the caller.
            return fetched[2]

        # Scoring helper used at evaluation time
        def predict(x_batch, y_batch):
            """
            Evaluate the model's predictions for one batch with dropout disabled.

            Returns the result of the session call for the single-element
            fetch list, i.e. a one-element list holding the predictions.
            """
            return sess.run(
                [r_men.predictions],
                {
                    r_men.input_x: x_batch,
                    r_men.input_y: y_batch,
                    # Keep probability 1.0 turns dropout off.
                    r_men.dropout_keep_prob: 1.0,
                },
            )

        def _score_in_batches(x_data, y_data, rows_per_batch):
            """Run ``predict`` over all rows of ``x_data`` in fixed-size batches.

            The data is padded with copies of its first row until its length is a
            multiple of ``rows_per_batch``; scores of the padding rows are dropped
            again. The remaining scores are returned reshaped to a single column.
            """
            num_rows = len(x_data)
            x_padded = np.tile(x_data, (1, 1))
            y_padded = np.tile(y_data, (1, 1))

            # Number of filler rows needed to reach the next multiple of rows_per_batch.
            num_pad = -num_rows % rows_per_batch
            if num_pad:
                x_padded = np.append(x_padded, np.tile(x_data[0], (num_pad, 1)), axis=0)
                y_padded = np.append(y_padded, np.tile(y_data[0], (num_pad, 1)), axis=0)

            scores = []
            for start in range(0, len(x_padded), rows_per_batch):
                stop = start + rows_per_batch
                # np.append without an axis flattens whatever predict returns.
                scores = np.append(scores, predict(x_padded[start:stop], y_padded[start:stop]))

            scores = scores[:num_rows]
            return np.reshape(scores, [len(scores), 1])

        def _fit_relation_thresholds(x_dev, y_dev, dev_scores):
            """Pick, per relation id, the score threshold with the best dev accuracy.

            Candidate thresholds run from the relation's minimum to its maximum
            score in steps of 0.01; the first candidate reaching the highest
            accuracy wins. If no candidate beats 0.0 the minimum score is used.
            """
            pairs_by_rel = {}
            for i, triple in enumerate(x_dev):
                pairs_by_rel.setdefault(triple[1], []).append([dev_scores[i][0], y_dev[i][0]])

            thresholds = {}
            for rel, pairs in pairs_by_rel.items():
                pairs = np.array(pairs)
                rel_scores = pairs[:, 0]
                rel_labels = pairs[:, 1]
                _min = np.min(rel_scores)
                _max = np.max(rel_scores)

                best_threshold = _min
                best_acc = 0.0
                for threshold in np.arange(_min, _max, 0.01):
                    # A score below the threshold is classified as 1, otherwise -1.
                    predicted = np.where(rel_scores < threshold, 1, -1)
                    acc = np.count_nonzero(predicted == rel_labels) * 1.0 / len(rel_labels)
                    if best_acc < acc:
                        best_acc = acc
                        best_threshold = threshold
                thresholds[rel] = best_threshold
            return thresholds

        def _thresholded_accuracy(x_data, y_data, scores, thresholds):
            """Return the fraction of rows whose thresholded prediction equals the label."""
            correct = 0
            for i, triple in enumerate(x_data):
                label = 1 if scores[i][0] < thresholds[triple[1]] else -1
                if label == y_data[i][0]:
                    correct += 1
            return correct * 1.0 / len(x_data)

        def test_classification(x_batch_dev, y_batch_dev, x_batch_test, y_batch_test):
            """Evaluate triple classification on the validation and test sets.

            Scores both sets with the current model, tunes one decision threshold
            per relation on the validation set and applies those thresholds to
            both sets. Labels are compared against 1 (score below the threshold)
            and -1 (otherwise). Both accuracies are printed.

            Returns:
                (validation accuracy, test accuracy) as floats.

            Raises:
                KeyError: if the test set contains a relation id that does not
                    occur in the validation set, since no threshold exists for it.
            """
            # Both splits must be fed in batches of the size the model was built
            # with. The validation path used to hard-code 2 * batch_size, which
            # only matches when neg_ratio is 1.
            rows_per_batch = (int(args.neg_ratio) + 1) * args.batch_size

            results_test = _score_in_batches(x_batch_test, y_batch_test, rows_per_batch)
            results = _score_in_batches(x_batch_dev, y_batch_dev, rows_per_batch)

            relThresholds = _fit_relation_thresholds(x_batch_dev, y_batch_dev, results)

            dev_acc = _thresholded_accuracy(x_batch_dev, y_batch_dev, results, relThresholds)
            print("Accuracy on validation set:", dev_acc)

            test_acc = _thresholded_accuracy(x_batch_test, y_batch_test, results_test, relThresholds)
            print("Accuracy on test set:", test_acc)
            return dev_acc, test_acc

        dev_acc_lst = []
        test_acc_lst = []
        # Enough batches per epoch to cover data_size items at batch_size per batch.
        num_batches_per_epoch = int((data_size - 1) / args.batch_size) + 1

        # The context manager guarantees the log file is flushed and closed even
        # if training or evaluation raises part-way through.
        with open(checkpoint_prefix + '.cls.' + '.txt', 'w') as wri:
            for epoch in range(1, args.num_epochs+1):
                # Sum of the per-batch losses over the whole epoch.
                loss = 0
                for _ in range(num_batches_per_epoch):
                    x_batch, y_batch = train_batch()
                    loss += train_step(x_batch, y_batch)
                print(loss)

                # Evaluate after every epoch; model selection uses the validation accuracy.
                dev_acc, test_acc = test_classification(x_valid, y_valid, x_test, y_test)
                dev_acc_lst.append(dev_acc)
                test_acc_lst.append(test_acc)
                wri.write("epoch " + str(epoch) + ": " + str(dev_acc) + "\n")

            if dev_acc_lst:
                index_valid_max = int(np.argmax(dev_acc_lst))
                # Epochs are numbered from 1 while the list index starts at 0.
                best_epoch = index_valid_max + 1
                print("Best acc in valid:", dev_acc_lst[index_valid_max])
                print('Acc in test:', test_acc_lst[index_valid_max])

                wri.write("\n--------------------------\n")
                wri.write("\nBest acc in valid at epoch " + str(best_epoch) + ": " + str(dev_acc_lst[index_valid_max]) + "\n")
                wri.write("\nAcc in test: " + str(test_acc_lst[index_valid_max]) + "\n")
            else:
                # np.argmax would fail on an empty list (num_epochs < 1).
                print("No epochs were run; nothing to report.")